
[llm] support tensorwise fp8/int8 training #10612

Open
wants to merge 17 commits into develop

Conversation

Contributor

@lugimzzz commented May 19, 2025

PR types

New features

PR changes

APIs

Description

Newly supported features:
1. Add all_reduce_max for weight scales and activation scales, so the scales stay consistent across different TP and data-parallel sharding strategies (a minimal sketch follows this list)
2. Support FP8/INT8 training with DP + TP + PP + Sharding stage1, using Unified Checkpoint to store weights and optimizer state
3. Switch the Hadamard matrix multiplication to a block-diagonal Hadamard matrix
4. Unify the FP8/INT8 training code paths
5. Add a Triton version of the AdamW optimizer for FP8 weights (with bf16 moments and offload support)
6. Support FP8/INT8 LoRA on the backbone model
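A minimal sketch of the scale synchronization in item 1, assuming a tensorwise abs-max scale; the function name and the qmax argument are illustrative, and group is whichever TP/DP communication group shares the tensor:

```python
import paddle
import paddle.distributed as dist


def tensorwise_absmax_scale(x, qmax, group=None):
    """Compute one abs-max scale for the whole tensor, then take the max across
    the communication group so every TP/DP shard of the same logical tensor
    quantizes with an identical scale."""
    scale = paddle.max(paddle.abs(x)) / qmax
    if group is not None:
        dist.all_reduce(scale, op=dist.ReduceOp.MAX, group=group, sync_op=True)
    return scale
```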

Features left for follow-up PRs:
1. FP8 weights are currently represented as paddle.int8 and stored as np.int8; they will be switched to a float8 representation later (pending framework support for fp8 set_value and concat)
2. Speed up the FP8/INT8 quant-matmul-dequant path and adapt the acceleration to MoE structures
3. Support Sharding stage2/3 for FP8/INT8 training (PP only supports stage1; low priority)



paddle-bot bot commented May 19, 2025

Thanks for your contribution!


codecov bot commented May 19, 2025

Codecov Report

Attention: Patch coverage is 16.57895% with 317 lines in your changes missing coverage. Please review.

Project coverage is 46.90%. Comparing base (23e4c1a) to head (7935b87).
Report is 6 commits behind head on develop.

Files with missing lines                        Patch %   Missing lines
paddlenlp/quantization/qat_utils.py              7.69%    72
paddlenlp/utils/optimizer.py                     7.69%    72
paddlenlp/utils/adamw_triton.py                 13.84%    56
paddlenlp/transformers/conversion_utils.py       9.37%    29
paddlenlp/transformers/model_utils.py           21.62%    29
paddlenlp/quantization/quantization_linear.py   10.34%    26
paddlenlp/quantization/hadamard_utils.py        16.66%    25
paddlenlp/quantization/quantization_utils.py    25.00%     3
paddlenlp/trainer/trainer.py                    60.00%     2
paddlenlp/utils/distributed.py                   0.00%     2
... and 1 more

❌ Your patch check has failed because the patch coverage (16.57%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (46.90%) is below the target coverage (58.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop   #10612      +/-   ##
===========================================
- Coverage    46.91%   46.90%   -0.02%     
===========================================
  Files          799      800       +1     
  Lines       132460   132519      +59     
===========================================
+ Hits         62148    62157       +9     
- Misses       70312    70362      +50     


@lugimzzz changed the title from "add uc" to "[llm] support tensorwise fp8 training" on May 21, 2025
@@ -478,8 +525,8 @@ def load_state_dict(
        scale_dict.update(res_scale_dict)

    if device == "cpu":
        with device_guard():
            for k in list(state_dict.keys()):
Contributor Author

Avoids the extra overhead of calling set_device repeatedly for every key.

"weight_only_int4",
"weight_only_int8",
]
elif isinstance(config.quantization_config.weight_quantize_algo, dict):
Contributor Author

weight_only_int8 does not support having different TP shards share the same scale, so flexibly converting wint8 weights across TP strategies is not supported for now.

Contributor Author

post_quantize means the weight is first split for tensor parallelism and then quantized (applies to wint4/wint8).
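A hedged sketch of the ordering described above; the function name and the split axis are illustrative, and quantize_fn stands in for the actual wint4/wint8 quantizer:

```python
import paddle


def shard_then_quantize(weight, tp_rank, tp_degree, quantize_fn):
    # post_quantize ordering: shard the full-precision weight for tensor
    # parallelism first, then quantize the local shard, so each rank ends up
    # with its own quantized weight and scale.
    shard = paddle.split(weight, num_or_sections=tp_degree, axis=-1)[tp_rank]
    return quantize_fn(shard)
```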

@@ -2537,6 +2615,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
# load pt weights early so that we know which dtype to init the model under
if not is_sharded and state_dict is None:
    # 4. loading non-sharded ckpt from the state dict
    # Quantization: Loading non-sharded ckpt does not support saving with merge_tensor_parallel
Contributor Author

Quantized loading and saving of non-safetensors weights is not considered for now.

@lugimzzz changed the title from "[llm] support tensorwise fp8 training" to "[llm] support tensorwise fp8/int8 training" on May 21, 2025
    return block


def create_hadamard_matrix(block_size, dtype):
Collaborator

What is the difference from the earlier random_hadamard_matrix?

if getattr(infohub, "hadamard") is None:
    setattr(infohub, "hadamard", {})

if block_size in infohub.hadamard:
Collaborator

If hadamard_matrix has no default value, things will break when this branch is not hit.
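For context on item 3 of the PR description, a minimal sketch of a block-diagonal Hadamard rotation, assuming a Sylvester construction and a power-of-two block size; the function names are illustrative and not the PR's hadamard_utils API:

```python
import numpy as np
import paddle


def create_hadamard_block(block_size, dtype="bfloat16"):
    # Sylvester construction of a single block; block_size is assumed to be a
    # power of two, and the block is normalized so the rotation is orthogonal.
    h = np.array([[1.0]])
    while h.shape[0] < block_size:
        h = np.block([[h, h], [h, -h]])
    return paddle.to_tensor(h / np.sqrt(block_size), dtype=dtype)


def apply_block_diag_hadamard(x, hadamard_block):
    # Multiplying by a block-diagonal Hadamard matrix is equivalent to reshaping
    # the last dimension into blocks (assumed divisible by block_size) and
    # rotating each block, which avoids materializing the full d x d matrix.
    block_size = hadamard_block.shape[0]
    orig_shape = x.shape
    out = x.reshape([-1, block_size]) @ hadamard_block
    return out.reshape(orig_shape)
```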

@@ -2107,16 +2109,6 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:

    optimizer_cls = AdamWCustom
    optimizer_kwargs.update(adam_kwargs)
elif args.optim == OptimizerNames.ADAMW_16BIT_MOMENT:
Collaborator

Why do these two AdamW implementations need to be removed?

@@ -318,8 +318,6 @@ class OptimizerNames(ExplicitEnum):
ADAFACTOR = "adafactor"
ADAMW_MINI = "adamw_mini"
ADAMW_CUSTOM = "adamw_custom"
ADAMW_16BIT_MOMENT = "adamw_16bit_moment"
Collaborator

Weren't these two used in any earlier configurations?

@@ -868,6 +868,10 @@ class TrainingArguments:
default="adamw",
metadata={"help": "The optimizer to use."},
)
use_lowprecision_moment: bool = field(
default=False,
metadata={"help": "AdamW use lowbit moment as parameter."},
Collaborator

How many bits does "lowbit" mean here? When is enabling it recommended, and what is the impact of enabling it? This needs to be explained clearly.
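One possible reading of use_lowprecision_moment, assuming "lowbit" means bf16 (16-bit) moments as in item 5 of the PR description; this is an illustrative single step, not the PR's Triton kernel:

```python
import paddle


def adamw_step_bf16_moments(param, grad, m_bf16, v_bf16, step, lr,
                            beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    # Keep the AdamW moments stored in bf16 to roughly halve optimizer-state
    # memory, but do the arithmetic in fp32 for numerical stability.
    g = grad.astype("float32")
    m = m_bf16.astype("float32") * beta1 + (1.0 - beta1) * g
    v = v_bf16.astype("float32") * beta2 + (1.0 - beta2) * g * g
    m_hat = m / (1.0 - beta1 ** step)
    v_hat = v / (1.0 - beta2 ** step)
    p = param.astype("float32")
    p = p - lr * (m_hat / (paddle.sqrt(v_hat) + eps) + weight_decay * p)
    # Cast the moments back to bf16 before storing them; this is where
    # precision is traded for memory.
    return p.astype(param.dtype), m.astype("bfloat16"), v.astype("bfloat16")
```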

@@ -996,6 +1000,10 @@ class TrainingArguments:
    default=False,
    metadata={"help": "Offload optimizer after optimizer.step()"},
)
tensorwise_offload_optimizer: Optional[bool] = field(
Collaborator

The help message does not explain this clearly. Why is this option needed?
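A hedged sketch of what a tensorwise optimizer offload does; illustrative pattern only, the trainer hooks and the exact calls used by this PR are not shown:

```python
import paddle


def offload_optimizer_state(state_tensors):
    # After optimizer.step(), move every optimizer state tensor to host memory
    # so its GPU copy can be freed.
    for name, t in state_tensors.items():
        state_tensors[name] = t.cpu()


def reload_optimizer_tensor(state_tensors, name):
    # Bring back a single tensor right before it is needed ("tensorwise"),
    # instead of reloading the whole optimizer state at once.
    state_tensors[name] = state_tensors[name].cuda()
    return state_tensors[name]
```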

@@ -445,7 +452,9 @@ def compute_metrics_do_generation(eval_preds):
    gen_args=gen_args,
    data_args=data_args,
)
trainable_parameters = [p for p in model.parameters() if not p.stop_gradient]
trainable_parameters = [
    p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)


Can this hardcoding be avoided? Or how is it guaranteed to actually take effect? At the very least there should be a log message.
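A hedged sketch of the suggested safeguard, logging whichever scale parameters the name match force-includes; the logger import path is an assumption:

```python
from paddlenlp.utils.log import logger  # assumed import path

trainable_parameters = []
for p in model.parameters():
    if not p.stop_gradient:
        trainable_parameters.append(p)
    elif "quantization_linear" in p.name and "w_1" in p.name:
        # Make the hardcoded name match visible in the training log.
        logger.info(f"Force-including quantization scale parameter {p.name} as trainable.")
        trainable_parameters.append(p)
```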

if self.weight_quantize_algo not in ["fp8linear", "a8w4linear", "fp8linear"]:
    self.quant_scale.is_distributed = False
else:
    self.quant_scale.is_distributed = True if self.is_mp else False


Does DP need to be considered here?

scale = paddle.max(paddle.abs(target_x)) / qmax
if group is not None:
    paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX, group=group, sync_op=True)
if state < quantization_config.apply_online_actscale_step:


Is this online scaling the counterpart of delayed scaling? It is not clear what this parameter affects; please explain it for users.
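One possible reading of apply_online_actscale_step based on the snippet above; the warm-up/cached split is an assumption for illustration, not the PR's confirmed behaviour:

```python
import paddle


def select_activation_scale(x, qmax, state, apply_online_actscale_step, cached_scale):
    # For the first `apply_online_actscale_step` steps, compute the activation
    # scale "online" from the current batch (and sync it as in the snippet
    # above); afterwards reuse a cached scale, i.e. delayed scaling.
    if state < apply_online_actscale_step:
        return paddle.max(paddle.abs(x)) / qmax
    return cached_scale
```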
